import os

import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm
import random
import xformers

import matplotlib.pyplot as plt
import math
import re


def _memory_efficient_attention_xformers(module, query, key, value):
    """Run xformers memory-efficient attention and merge heads back into channels.

    ``query``/``key``/``value`` are the head-split tensors produced by
    ``module.head_to_batch_dim``; the attention output is passed through
    ``module.batch_to_head_dim`` before being returned. No attention bias/mask
    is applied.
    """
    # Import the submodule explicitly: a bare ``import xformers`` does not by
    # itself guarantee that ``xformers.ops`` has been loaded.
    import xformers.ops

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()
    hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
    return module.batch_to_head_dim(hidden_states)


class RegionalGenerator:
    """Stable Diffusion sampler that splits the image between two regional prompts.

    Every cross-attention layer (``attn2``) of the UNet gets a hooked ``forward``
    (see ``hook_forward``). While ``self.double`` is True, the text-embedding batch
    is ``[base, region_a, region_b, negative]`` (each repeated ``batch_size``
    times). The conditional query is attended against the three positive prompts.
    The two regional outputs are stitched spatially (left/right or top/bottom, see
    ``pos``), and the stitched map is blended with the base-prompt output using
    ``base_ratio``.
    """

    def __init__(self, model_id, dtype=torch.float32, device="cuda"):
        self.tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_id, subfolder='text_encoder').eval().to(device, dtype=dtype)
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder='vae').eval().to(device, dtype=dtype)
        self.vae.enable_slicing()

        self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder='unet').eval().to(device, dtype=dtype)
        self.unet.set_use_memory_efficient_attention_xformers(True)

        self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

        self.dtype = dtype
        self.device = device

        # State read by the hooked attention layers. Every __call__ overwrites it;
        # these defaults make a plain UNet forward (outside __call__) behave normally.
        self.double = False
        self.pos = "left"
        self.base_ratio = 0.3

        self.hook_forwards(self.unet)

    # Transform to Pillow
    def decode_latents(self, latents):
        """Decode VAE latents into a list of ``PIL.Image`` objects (uint8 RGB).

        The latent scaling factor is taken from the VAE config when present and
        falls back to 0.18215 otherwise.
        """
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        # map [-1, 1] -> [0, 1] -> [0, 255]
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        images = (images * 255).round().astype("uint8")
        return [Image.fromarray(image) for image in images]

    def encode_prompts(self, prompts):
        """Tokenize and encode *prompts* with the CLIP text encoder.

        Every prompt is padded/truncated to the tokenizer's ``model_max_length`` so
        all rows share one sequence length (required by the hooked attention,
        which stacks the prompts along the batch axis). Returns the encoder's
        ``last_hidden_state`` on ``self.device`` in ``self.dtype``.
        """
        with torch.no_grad():
            tokens = self.tokenizer(
                prompts,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(self.device)
            embs = self.text_encoder(tokens).last_hidden_state
        return embs.to(self.device, dtype=self.dtype)

    def __call__(
            self,
            prompts,
            negative_prompt,
            batch_size=1,
            pos="left",
            height: int = 512,
            width: int = 512,
            guidance_scale: float = 7.0,
            num_inference_steps: int = 50,
            seed=42,
            base_ratio=0.3,
            end_steps: float = 1,
            preview_every: int = 10,
    ):
        '''Generate images with regional prompting.

        prompts: exactly three prompts ``[base, region_a, region_b]``. ``region_a``
            covers the left half (``pos == "left"``) or otherwise the top half;
            ``region_b`` covers the remaining half.
        negative_prompt: prompt for the unconditional branch of CFG.
        batch_size: number of images; every prompt is repeated this many times.
        pos: ``"left"`` for a left/right split, any other value for top/bottom.
        height, width: output size in pixels (latent grid is ``// 8``).
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler steps.
        seed: seeds Python/NumPy/torch RNGs when ``>= 0``.
        base_ratio: weight of the base prompt when blending with the stitched regions.
        end_steps: regional prompting is switched off once the step index exceeds
            ``num_inference_steps * end_steps`` (1 keeps it on for all steps).
        preview_every: plot an intermediate decode every this many steps;
            ``0``/``None`` disables previews.

        Returns a list of ``PIL.Image``.
        Raises ValueError if ``prompts`` does not have exactly three entries.
        '''
        if len(prompts) != 3:
            # The hooked attention duplicates the conditional query three times, so
            # the key/value batch must be [base, region_a, region_b, negative].
            raise ValueError(
                f"prompts must be [base, region_a, region_b]; got {len(prompts)} prompt(s)")

        if seed >= 0:
            self.torch_fix_seed(seed=seed)

        self.base_ratio = base_ratio
        # Read by the hooked attention layers; UNet2DConditionModel.forward does not
        # accept extra keyword arguments such as `pos`.
        self.pos = pos

        all_prompts = []
        for prompt in prompts:
            all_prompts.extend([prompt] * batch_size)
        all_prompts.extend([negative_prompt] * batch_size)

        text_embs = self.encode_prompts(all_prompts)

        # set timestep
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # Initialize the noise [batch_size, 4, height // 8, width // 8]
        latents = torch.randn(batch_size, 4, height // 8, width // 8).to(self.device, dtype=self.dtype)
        latents = latents * self.scheduler.init_noise_sigma

        self.height = height // 8
        self.width = width // 8
        self.pixels = self.height * self.width

        self.double = True
        progress_bar = tqdm(total=num_inference_steps, desc="Total Steps", leave=False)
        try:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Past the cut-off, drop the regional prompts and fall back to plain CFG.
                if i > num_inference_steps * end_steps and self.double:
                    cond, _, _, negative = text_embs.chunk(4)  # cond, region_a, region_b, negative
                    text_embs = torch.cat([cond, negative])
                    self.double = False

                # predict noise
                with torch.no_grad():
                    noise_pred = self.unet(sample=latent_model_input, timestep=t,
                                           encoder_hidden_states=text_embs).sample

                # negative CFG
                noise_pred_text, noise_pred_negative = noise_pred.chunk(2)
                noise_pred = noise_pred_negative + guidance_scale * (noise_pred_text - noise_pred_negative)

                # Get denoised latents
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                progress_bar.update(1)
                if preview_every and i % preview_every == 0:
                    cur_img = self.decode_latents(latents)[0]
                    plt.figure(figsize=(20, 20))
                    plt.imshow(np.array(cur_img))
                    plt.title(f"Diffusion Step {i}")
                    plt.axis('off')
                    plt.show()
        finally:
            progress_bar.close()

        return self.decode_latents(latents)

    def _merge_regions(self, hidden_states, seq_len, pos):
        """Stitch the two regional attention outputs and blend them with the base output.

        *hidden_states* is the attention output for the 4-way batch
        ``[base, region_a, region_b, uncond]`` with ``seq_len`` spatial tokens per
        row. Returns the 2-way batch ``[blended_cond, uncond]``.
        """
        # Downsampling factor of the current UNet level relative to the latent grid.
        rate = math.isqrt(self.pixels // seq_len)
        height = self.height // rate
        width = self.width // rate

        cond, first, second, uncond = hidden_states.chunk(4)
        first = first.reshape(first.shape[0], height, width, first.shape[2])
        second = second.reshape(second.shape[0], height, width, second.shape[2])

        if pos == "left":
            # region_a on the left half, region_b on the right half
            merged = torch.cat([first[:, :, :width // 2, :], second[:, :, width // 2:, :]], dim=2)
        else:
            # region_a on the top half, region_b on the bottom half
            merged = torch.cat([first[:, :height // 2, :, :], second[:, height // 2:, :, :]], dim=1)

        merged = merged.reshape(cond.shape[0], -1, cond.shape[2])
        # weighted sum with the base prompt
        cond = merged * (1 - self.base_ratio) + cond * self.base_ratio
        return torch.cat([cond, uncond])

    def hook_forward(self, module):
        """Build a replacement ``forward`` for the cross-attention *module*.

        The returned function reuses the module's projections (``to_q``/``to_k``/
        ``to_v``/``to_out``). When ``self.double`` is set, it performs the regional
        attention described on the class. ``pos`` defaults to ``self.pos``.
        """
        def forward(hidden_states, encoder_hidden_states=None, pos=None, attention_mask=None):
            # NOTE(review): attention_mask is accepted for signature compatibility but ignored.
            if pos is None:
                pos = self.pos

            query = module.to_q(hidden_states)

            if self.double:
                # (q_cond, q_uncond) -> (q_cond, q_cond, q_cond, q_uncond)
                query_cond, query_uncond = query.chunk(2)
                query = torch.cat([query_cond, query_cond, query_cond, query_uncond])

            context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
            key = module.to_k(context)
            value = module.to_v(context)

            query = module.head_to_batch_dim(query)
            key = module.head_to_batch_dim(key)
            value = module.head_to_batch_dim(value)

            hidden_states = _memory_efficient_attention_xformers(module, query, key, value)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)

            if self.double:
                hidden_states = self._merge_regions(hidden_states, query.shape[1], pos)

            hidden_states = module.to_out[0](hidden_states)
            hidden_states = module.to_out[1](hidden_states)
            return hidden_states

        return forward

    # rewrite unet forward()
    def hook_forwards(self, root_module: torch.nn.Module):
        """Replace ``forward`` on every cross-attention ``Attention`` module named ``*attn2*``."""
        for name, module in root_module.named_modules():
            if "attn2" in name and module.__class__.__name__ == "Attention":
                print(f'{name}:{module.__class__.__name__}')
                module.forward = self.hook_forward(module)

    # set random seed
    def torch_fix_seed(self, seed=42):
        """Seed Python, NumPy and torch RNGs and request deterministic kernels."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # Call (not assign!) the torch switch; warn_only keeps ops without a
        # deterministic implementation running with a warning instead of raising.
        torch.use_deterministic_algorithms(True, warn_only=True)


# Inline tags look like "[name: value]". The pos/agent extractors accept any
# whitespace (including none) after the colon, like the generic tag cleaners.
_POS_TAG_RE = re.compile(r'\[pos:\s*([^\]]+)\]')
_AGENT_TAG_RE = re.compile(r'\[agent:\s*([^\]]+)\]')
_TAG_WITH_SPACE_RE = re.compile(r'\[[^\]]+: ([^\]]+)\]')
_TAG_NO_SPACE_RE = re.compile(r'\[[^\]:]+:([^\]]+)\]')


def parse_tagged_text(input_text):
    """Split tagged text into agents, a clean description, and positions.

    *input_text* is split into paragraphs on blank lines; surrounding whitespace
    and double quotes are stripped from each paragraph. Within each paragraph:

    * every ``[pos: ...]`` tag is recorded (in order) and removed from the text;
    * every ``[agent: ...]`` name is recorded, normalised to capitalised words;
    * any remaining ``[tag: value]`` / ``[tag:value]`` is replaced by ``value``,
      and whitespace is collapsed.

    Returns ``[agents, description, [positions]]``: agents are the sorted unique
    normalised names joined by ", ", description is the cleaned paragraphs joined
    by a space, and positions are joined by ", " inside a one-element list.
    """
    paragraphs = [p.strip().strip('"') for p in input_text.strip().split('\n\n') if p.strip()]

    agents = set()
    cleaned_paragraphs = []
    positions = []

    for paragraph in paragraphs:
        # Record every position tag (not just the first) before removing them all.
        positions.extend(p.strip() for p in _POS_TAG_RE.findall(paragraph) if p.strip())
        working_text = _POS_TAG_RE.sub('', paragraph)

        for agent in _AGENT_TAG_RE.findall(working_text):
            # Normalize agent names by capitalizing the first letter of each word
            agents.add(' '.join(word.capitalize() for word in agent.strip().split()))

        # Clean the text - replace every remaining tag with its value
        clean_text = _TAG_WITH_SPACE_RE.sub(r'\1', working_text)
        clean_text = _TAG_NO_SPACE_RE.sub(r'\1', clean_text)
        cleaned_paragraphs.append(' '.join(clean_text.split()))

    return [", ".join(sorted(agents)), " ".join(cleaned_paragraphs), [", ".join(positions)]]



def main():
    """Turn test fMRI data into tagged text, then render it with regional prompting.

    Loads a pickled text model from ``model.pt`` and the ``test_fmri`` array from
    ``data/data.npz``. It parses the generated tagged text, renders the image(s)
    with ``RegionalGenerator`` and shows them. It then saves them as
    ``saved_imgs/generated_image_<i>.png`` (the directory is created if missing).
    """
    model_id = "stabilityai/stable-diffusion-2-1"

    torch.cuda.set_device(0)
    pipe = RegionalGenerator(model_id, dtype=torch.float16)

    model_save_path = "model.pt"
    # NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted
    # checkpoints. Newer torch releases default to weights_only=True, which may refuse
    # a fully pickled model object; confirm against the installed torch version.
    model = torch.load(model_save_path)

    data_dir = "data"
    # allow_pickle is necessary for object arrays (and likewise trusts the file's contents)
    loaded_data = np.load(os.path.join(data_dir, "data.npz"), allow_pickle=True)
    test_data = loaded_data['test_fmri']

    pred_text = parse_tagged_text(model.generate_sentences(test_data))
    # NOTE(review): RegionalGenerator.__call__ combines prompts as [base, region_a, region_b]
    # (three entries), but only two are passed here (agents, description); confirm the
    # intended prompt layout.
    prompt = [pred_text[0], pred_text[1]]
    # NOTE(review): this is the comma-joined list of all [pos: ...] tags; the generator only
    # treats the exact value "left" as a left/right split (anything else is top/bottom).
    pos = pred_text[2][0]

    negative_prompt = "worst quality, low quality, medium quality, deleted, lowres, comic, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"

    images = pipe(prompt, negative_prompt,
                  batch_size=1,
                  pos=pos,
                  num_inference_steps=40,  # sampling steps
                  height=896,
                  width=640,
                  end_steps=1,
                  base_ratio=0.4,
                  seed=42,
                  )

    plt.figure(figsize=(20, 20))
    rows = math.ceil(len(images) / 4)
    for i, image in enumerate(images):
        plt.subplot(rows, 4, i + 1)
        plt.imshow(np.array(image))
        plt.axis('off')
    plt.show()

    # save the images; create the output directory so the first run does not fail
    img_dir = "saved_imgs"
    os.makedirs(img_dir, exist_ok=True)
    for i, image in enumerate(images):
        image.save(os.path.join(img_dir, f"generated_image_{i}.png"))


if __name__ == "__main__":
    # Script entry point: run main() only when this file is executed, not imported.
    main()
